{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Binary Classification with different optimizers, schedulers, etc.\n",
    "\n",
    "In this notebook we will use the Adult Census dataset. Download the data from [here](https://www.kaggle.com/wenruliu/adult-income-dataset/downloads/adult.csv/2)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
    "from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
    "from pytorch_widedeep.metrics import Accuracy, Recall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>educational-num</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>gender</th>\n",
       "      <th>capital-gain</th>\n",
       "      <th>capital-loss</th>\n",
       "      <th>hours-per-week</th>\n",
       "      <th>native-country</th>\n",
       "      <th>income</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>25</td>\n",
       "      <td>Private</td>\n",
       "      <td>226802</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>89814</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>28</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>336951</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>160323</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&gt;50K</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>18</td>\n",
       "      <td>?</td>\n",
       "      <td>103497</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>&lt;=50K</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  workclass  fnlwgt     education  educational-num      marital-status  \\\n",
       "0   25    Private  226802          11th                7       Never-married   \n",
       "1   38    Private   89814       HS-grad                9  Married-civ-spouse   \n",
       "2   28  Local-gov  336951    Assoc-acdm               12  Married-civ-spouse   \n",
       "3   44    Private  160323  Some-college               10  Married-civ-spouse   \n",
       "4   18          ?  103497  Some-college               10       Never-married   \n",
       "\n",
       "          occupation relationship   race  gender  capital-gain  capital-loss  \\\n",
       "0  Machine-op-inspct    Own-child  Black    Male             0             0   \n",
       "1    Farming-fishing      Husband  White    Male             0             0   \n",
       "2    Protective-serv      Husband  White    Male             0             0   \n",
       "3  Machine-op-inspct      Husband  Black    Male          7688             0   \n",
       "4                  ?    Own-child  White  Female             0             0   \n",
       "\n",
       "   hours-per-week native-country income  \n",
       "0              40  United-States  <=50K  \n",
       "1              50  United-States  <=50K  \n",
       "2              40  United-States   >50K  \n",
       "3              40  United-States   >50K  \n",
       "4              30  United-States  <=50K  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('data/adult/adult.csv.zip')\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>educational_num</th>\n",
       "      <th>marital_status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>gender</th>\n",
       "      <th>capital_gain</th>\n",
       "      <th>capital_loss</th>\n",
       "      <th>hours_per_week</th>\n",
       "      <th>native_country</th>\n",
       "      <th>income_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>25</td>\n",
       "      <td>Private</td>\n",
       "      <td>226802</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>89814</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>28</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>336951</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>160323</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>18</td>\n",
       "      <td>?</td>\n",
       "      <td>103497</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  workclass  fnlwgt     education  educational_num      marital_status  \\\n",
       "0   25    Private  226802          11th                7       Never-married   \n",
       "1   38    Private   89814       HS-grad                9  Married-civ-spouse   \n",
       "2   28  Local-gov  336951    Assoc-acdm               12  Married-civ-spouse   \n",
       "3   44    Private  160323  Some-college               10  Married-civ-spouse   \n",
       "4   18          ?  103497  Some-college               10       Never-married   \n",
       "\n",
       "          occupation relationship   race  gender  capital_gain  capital_loss  \\\n",
       "0  Machine-op-inspct    Own-child  Black    Male             0             0   \n",
       "1    Farming-fishing      Husband  White    Male             0             0   \n",
       "2    Protective-serv      Husband  White    Male             0             0   \n",
       "3  Machine-op-inspct      Husband  Black    Male          7688             0   \n",
       "4                  ?    Own-child  White  Female             0             0   \n",
       "\n",
       "   hours_per_week native_country  income_label  \n",
       "0              40  United-States             0  \n",
       "1              50  United-States             0  \n",
       "2              40  United-States             1  \n",
       "3              40  United-States             1  \n",
       "4              30  United-States             0  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For convenience, we'll replace '-' with '_'\n",
    "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n",
    "# binary target\n",
    "df['income_label'] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n",
    "df.drop('income', axis=1, inplace=True)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing the data\n",
    "\n",
    "Have a look to notebooks one and two if you want to get a good understanding of the next few lines of code (although there is no need to use the package)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide_cols = ['education', 'relationship','workclass','occupation','native_country','gender']\n",
    "crossed_cols = [('education', 'occupation'), ('native_country', 'occupation')]\n",
    "cat_embed_cols = [('education',16), ('relationship',8), ('workclass',16), ('occupation',16),('native_country',16)]\n",
    "continuous_cols = [\"age\",\"hours_per_week\"]\n",
    "target_col = 'income_label'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TARGET\n",
    "target = df[target_col].values\n",
    "\n",
    "# WIDE\n",
    "preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n",
    "X_wide = preprocess_wide.fit_transform(df)\n",
    "\n",
    "# DEEP\n",
    "preprocess_deep = DensePreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n",
    "X_deep = preprocess_deep.fit_transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[  1  17  23 ...  89  91 316]\n",
      " [  2  18  23 ...  89  92 317]\n",
      " [  3  18  24 ...  89  93 318]\n",
      " ...\n",
      " [  2  20  23 ...  90 103 323]\n",
      " [  2  17  23 ...  89 103 323]\n",
      " [  2  21  29 ...  90 115 324]]\n",
      "(48842, 8)\n"
     ]
    }
   ],
   "source": [
    "print(X_wide)\n",
    "print(X_wide.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 0.          0.          0.         ...  0.         -0.99512893\n",
      "  -0.03408696]\n",
      " [ 1.          1.          0.         ...  0.         -0.04694151\n",
      "   0.77292975]\n",
      " [ 2.          1.          1.         ...  0.         -0.77631645\n",
      "  -0.03408696]\n",
      " ...\n",
      " [ 1.          3.          0.         ...  0.          1.41180837\n",
      "  -0.03408696]\n",
      " [ 1.          0.          0.         ...  0.         -1.21394141\n",
      "  -1.64812038]\n",
      " [ 1.          4.          6.         ...  0.          0.97418341\n",
      "  -0.03408696]]\n",
      "(48842, 7)\n"
     ]
    }
   ],
   "source": [
    "print(X_deep)\n",
    "print(X_deep.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, you can run a wide and deep model in just a few lines of code\n",
    "\n",
    "Let's now see how to use `WideDeep` with varying parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  2.1 Dropout and Batchnorm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
    "# We can add dropout and batchnorm to the dense layers\n",
    "deepdense = DeepDense(hidden_layers=[64,32], dropout=[0.5, 0.5], batchnorm=True,\n",
    "                      deep_column_idx=preprocess_deep.deep_column_idx,\n",
    "                      embed_input=preprocess_deep.embeddings_input,\n",
    "                      continuous_cols=continuous_cols)\n",
    "model = WideDeep(wide=wide, deepdense=deepdense)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "WideDeep(\n",
       "  (wide): Wide(\n",
       "    (wide_linear): Embedding(797, 1, padding_idx=0)\n",
       "  )\n",
       "  (deepdense): Sequential(\n",
       "    (0): DeepDense(\n",
       "      (embed_layers): ModuleDict(\n",
       "        (emb_layer_education): Embedding(17, 16)\n",
       "        (emb_layer_native_country): Embedding(43, 16)\n",
       "        (emb_layer_occupation): Embedding(16, 16)\n",
       "        (emb_layer_relationship): Embedding(7, 8)\n",
       "        (emb_layer_workclass): Embedding(10, 16)\n",
       "      )\n",
       "      (embed_dropout): Dropout(p=0.0, inplace=False)\n",
       "      (dense): Sequential(\n",
       "        (dense_layer_0): Sequential(\n",
       "          (0): Linear(in_features=74, out_features=64, bias=True)\n",
       "          (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "          (2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (3): Dropout(p=0.5, inplace=False)\n",
       "        )\n",
       "        (dense_layer_1): Sequential(\n",
       "          (0): Linear(in_features=64, out_features=32, bias=True)\n",
       "          (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "          (2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (3): Dropout(p=0.5, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): Linear(in_features=32, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use different initializers, optimizers and learning rate schedulers for each `branch` of the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  Optimizers, LR schedulers, Initializers and Callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_widedeep.initializers import KaimingNormal, XavierNormal\n",
    "from pytorch_widedeep.callbacks import ModelCheckpoint, LRHistory, EarlyStopping\n",
    "from pytorch_widedeep.optim import RAdam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimizers\n",
    "wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.03)\n",
    "deep_opt = RAdam(model.deepdense.parameters(), lr=0.01)\n",
    "# LR Schedulers\n",
    "wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3)\n",
    "deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "the component-dependent settings must be passed as dictionaries, while general settings are simply lists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Component-dependent settings as Dict\n",
    "optimizers = {'wide': wide_opt, 'deepdense':deep_opt}\n",
    "schedulers = {'wide': wide_sch, 'deepdense':deep_sch}\n",
    "initializers = {'wide': KaimingNormal, 'deepdense':XavierNormal}\n",
    "# General settings as List\n",
    "callbacks = [LRHistory(n_epochs=10), EarlyStopping, ModelCheckpoint(filepath='model_weights/wd_out')]\n",
    "metrics = [Accuracy, Recall]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(method='binary', optimizers=optimizers, lr_schedulers=schedulers, \n",
    "              initializers=initializers,\n",
    "              callbacks=callbacks,\n",
    "              metrics=metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/153 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 153/153 [00:02<00:00, 72.33it/s, loss=0.503, metrics={'acc': 0.7885, 'rec': 0.4864}]\n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 127.72it/s, loss=0.386, metrics={'acc': 0.7962, 'rec': 0.4998}]\n",
      "epoch 2: 100%|██████████| 153/153 [00:02<00:00, 71.76it/s, loss=0.374, metrics={'acc': 0.8268, 'rec': 0.5242}]\n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 126.72it/s, loss=0.372, metrics={'acc': 0.8277, 'rec': 0.5281}]\n",
      "epoch 3: 100%|██████████| 153/153 [00:02<00:00, 73.21it/s, loss=0.367, metrics={'acc': 0.8298, 'rec': 0.5242}]\n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 126.68it/s, loss=0.37, metrics={'acc': 0.8303, 'rec': 0.5279}]\n",
      "epoch 4: 100%|██████████| 153/153 [00:02<00:00, 71.37it/s, loss=0.36, metrics={'acc': 0.8319, 'rec': 0.5372}] \n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 128.64it/s, loss=0.369, metrics={'acc': 0.8324, 'rec': 0.5412}]\n",
      "epoch 5: 100%|██████████| 153/153 [00:02<00:00, 71.53it/s, loss=0.359, metrics={'acc': 0.8322, 'rec': 0.5378}]\n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 119.31it/s, loss=0.369, metrics={'acc': 0.8325, 'rec': 0.5412}]\n",
      "epoch 6: 100%|██████████| 153/153 [00:02<00:00, 71.37it/s, loss=0.359, metrics={'acc': 0.8322, 'rec': 0.5361}]\n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 125.99it/s, loss=0.369, metrics={'acc': 0.8326, 'rec': 0.5398}]\n",
      "epoch 7: 100%|██████████| 153/153 [00:02<00:00, 70.20it/s, loss=0.358, metrics={'acc': 0.8329, 'rec': 0.5396}]\n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 124.88it/s, loss=0.369, metrics={'acc': 0.8331, 'rec': 0.5416}]\n",
      "epoch 8: 100%|██████████| 153/153 [00:02<00:00, 70.75it/s, loss=0.358, metrics={'acc': 0.833, 'rec': 0.5374}] \n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 125.81it/s, loss=0.369, metrics={'acc': 0.8331, 'rec': 0.5397}]\n",
      "epoch 9: 100%|██████████| 153/153 [00:02<00:00, 70.40it/s, loss=0.358, metrics={'acc': 0.833, 'rec': 0.5368}] \n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 125.07it/s, loss=0.369, metrics={'acc': 0.8331, 'rec': 0.5391}]\n",
      "epoch 10: 100%|██████████| 153/153 [00:02<00:00, 70.20it/s, loss=0.358, metrics={'acc': 0.8329, 'rec': 0.537}] \n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 124.43it/s, loss=0.369, metrics={'acc': 0.8331, 'rec': 0.5392}]\n"
     ]
    }
   ],
   "source": [
    "model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=10, batch_size=256, val_split=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__call__',\n",
       " '__class__',\n",
       " '__delattr__',\n",
       " '__dict__',\n",
       " '__dir__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattr__',\n",
       " '__getattribute__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__le__',\n",
       " '__lt__',\n",
       " '__module__',\n",
       " '__ne__',\n",
       " '__new__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__setattr__',\n",
       " '__setstate__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__subclasshook__',\n",
       " '__weakref__',\n",
       " '_apply',\n",
       " '_backward_hooks',\n",
       " '_buffers',\n",
       " '_forward_hooks',\n",
       " '_forward_pre_hooks',\n",
       " '_get_name',\n",
       " '_load_from_state_dict',\n",
       " '_load_state_dict_pre_hooks',\n",
       " '_loss_fn',\n",
       " '_lr_scheduler_step',\n",
       " '_modules',\n",
       " '_named_members',\n",
       " '_parameters',\n",
       " '_predict',\n",
       " '_register_load_state_dict_pre_hook',\n",
       " '_register_state_dict_hook',\n",
       " '_replicate_for_data_parallel',\n",
       " '_save_to_state_dict',\n",
       " '_slow_forward',\n",
       " '_state_dict_hooks',\n",
       " '_train_val_split',\n",
       " '_training_step',\n",
       " '_validation_step',\n",
       " '_version',\n",
       " '_warm_up',\n",
       " 'add_module',\n",
       " 'apply',\n",
       " 'batch_size',\n",
       " 'bfloat16',\n",
       " 'buffers',\n",
       " 'callback_container',\n",
       " 'callbacks',\n",
       " 'children',\n",
       " 'class_weight',\n",
       " 'compile',\n",
       " 'cpu',\n",
       " 'cuda',\n",
       " 'cyclic',\n",
       " 'deepdense',\n",
       " 'deephead',\n",
       " 'deepimage',\n",
       " 'deeptext',\n",
       " 'double',\n",
       " 'dump_patches',\n",
       " 'early_stop',\n",
       " 'eval',\n",
       " 'extra_repr',\n",
       " 'fit',\n",
       " 'float',\n",
       " 'forward',\n",
       " 'get_embeddings',\n",
       " 'half',\n",
       " 'history',\n",
       " 'initializer',\n",
       " 'load_state_dict',\n",
       " 'lr_history',\n",
       " 'lr_scheduler',\n",
       " 'method',\n",
       " 'metric',\n",
       " 'modules',\n",
       " 'named_buffers',\n",
       " 'named_children',\n",
       " 'named_modules',\n",
       " 'named_parameters',\n",
       " 'optimizer',\n",
       " 'parameters',\n",
       " 'predict',\n",
       " 'predict_proba',\n",
       " 'register_backward_hook',\n",
       " 'register_buffer',\n",
       " 'register_forward_hook',\n",
       " 'register_forward_pre_hook',\n",
       " 'register_parameter',\n",
       " 'requires_grad_',\n",
       " 'seed',\n",
       " 'share_memory',\n",
       " 'state_dict',\n",
       " 'to',\n",
       " 'train',\n",
       " 'train_running_loss',\n",
       " 'training',\n",
       " 'transforms',\n",
       " 'type',\n",
       " 'valid_running_loss',\n",
       " 'verbose',\n",
       " 'wide',\n",
       " 'with_focal_loss',\n",
       " 'zero_grad']"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dir(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You see that, among many methods and attributes we have the `history` and `lr_history` attributes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.history.epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'train_loss': [0.5026861273385341, 0.37383826573689777, 0.36658557158669614, 0.3601557047538508, 0.3594148938172783, 0.35907501001763187, 0.358282413942362, 0.35823015644659406, 0.35819698957835927, 0.3581014702133104], 'train_acc': [0.788549637857344, 0.8267857599877153, 0.8297545619737414, 0.8318787909809843, 0.8321859084278146, 0.832237094668953, 0.832902515803752, 0.8329792951654595, 0.8329537020448904, 0.8329281089243211], 'train_rec': [0.48636218905448914, 0.5242272019386292, 0.5242272019386292, 0.5371697545051575, 0.5378115177154541, 0.5361000895500183, 0.5396299362182617, 0.5373836755752563, 0.5368488430976868, 0.5369558334350586], 'val_loss': [0.38589231249613637, 0.371902360365941, 0.36999432627971357, 0.36935041348139447, 0.3691598016482133, 0.36905216712218064, 0.36900061674607104, 0.36898223635477895, 0.36896658937136334, 0.36896434120642835], 'val_acc': [0.79624094017444, 0.8277302321772245, 0.8302895049342779, 0.832357397321977, 0.8325211907784285, 0.8325826133245977, 0.833073993693952, 0.8331354162401212, 0.8330944678760084, 0.833073993693952], 'val_rec': [0.4997860789299011, 0.5281081795692444, 0.5279369950294495, 0.5411996245384216, 0.5411996245384216, 0.5398305654525757, 0.5416274666786194, 0.5396594405174255, 0.5391460657119751, 0.5392315983772278]}\n"
     ]
    }
   ],
   "source": [
    "print(model.history._history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'lr_wide_0': [0.03, 0.03, 0.03, 0.003, 0.003, 0.003, 0.00030000000000000003, 0.00030000000000000003, 0.00030000000000000003, 3.0000000000000004e-05], 'lr_deepdense_0': [0.01, 0.01, 0.01, 0.01, 0.01, 0.001, 0.001, 0.001, 0.001, 0.001]}\n"
     ]
    }
   ],
   "source": [
    "print(model.lr_history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the learning rate effectively decreases by a factor of 0.1 (the default) after the corresponding `step_size`. Note that the keys of the dictionary have a suffix `_0`. This is because if you pass different parameter groups to the torch optimizers, these will also be recorded. We'll see this in the `Regression` notebook. \n",
    "\n",
    "And I guess one has a good idea of how to use the package. Before we leave this notebook just mentioning that the `WideDeep` class comes with a useful method to \"rescue\" the learned embeddings. For example, let's say I want to use the embeddings learned for the different levels of the categorical feature `education`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'11th': array([ 0.33238176,  0.02123132,  0.42671534, -0.16836806,  0.04070434,\n",
       "         0.21476945, -0.05866506,  0.09599391,  0.21264766, -0.08261641,\n",
       "        -0.4364204 ,  0.5176953 , -0.17785792,  0.1990719 ,  0.05055304,\n",
       "        -0.05390744], dtype=float32),\n",
       " 'HS-grad': array([ 0.1851779 , -0.0601109 , -0.04134565, -0.17099169,  0.01647249,\n",
       "         0.1691518 , -0.03775224, -0.01711482, -0.13714994, -0.02202759,\n",
       "        -0.2350222 ,  0.20368417,  0.06420711,  0.08465873,  0.11443923,\n",
       "        -0.28585908], dtype=float32),\n",
       " 'Assoc-acdm': array([-0.2891686 , -0.25329128, -0.03977084,  0.34204823,  0.4393897 ,\n",
       "         0.24583909, -0.08771466,  0.3398704 ,  0.06197336, -0.09200054,\n",
       "         0.13266966, -0.27940965, -0.10639463,  0.16516595,  0.20191231,\n",
       "        -0.11804624], dtype=float32),\n",
       " 'Some-college': array([ 0.17284533, -0.34509236, -0.22175975, -0.11192639,  0.14154772,\n",
       "         0.04188053,  0.14860624,  0.28312132,  0.06071718, -0.10315312,\n",
       "        -0.05902205, -0.03197744,  0.20363455,  0.04027565,  0.43063605,\n",
       "         0.21163562], dtype=float32),\n",
       " '10th': array([ 0.13888928,  0.28386956,  0.18166119,  0.02652328,  0.11637231,\n",
       "         0.24056876, -0.06386037,  0.05930374,  0.04393852,  0.17677549,\n",
       "         0.27980283, -0.01221516,  0.12281907,  0.04273703,  0.22282158,\n",
       "        -0.25718638], dtype=float32),\n",
       " 'Prof-school': array([ 0.26996085,  0.06557842,  0.0957497 ,  0.06524102,  0.05351401,\n",
       "         0.34774455, -0.39007127, -0.35276353, -0.19460988,  0.06306136,\n",
       "        -0.03555794,  0.02946662,  0.47177076,  0.21887466,  0.34440616,\n",
       "         0.17761633], dtype=float32),\n",
       " '7th-8th': array([-0.14013144, -0.20337081,  0.6704599 , -0.10210201,  0.1633953 ,\n",
       "        -0.03677108, -0.04664218, -0.13967332, -0.02610652, -0.15920916,\n",
       "        -0.18137608, -0.01846946,  0.35807863,  0.0148629 ,  0.2857368 ,\n",
       "         0.28930005], dtype=float32),\n",
       " 'Bachelors': array([-0.38666266,  0.17745058, -0.6287257 ,  0.22080924,  0.25037012,\n",
       "        -0.10224682,  0.5612052 , -0.24709803,  0.03214271, -0.22835065,\n",
       "        -0.14132145,  0.3010941 , -0.23835489,  0.08622   , -0.04518703,\n",
       "         0.31074366], dtype=float32),\n",
       " 'Masters': array([-0.41403466, -0.33947882,  0.14072244, -0.22146806, -0.18230349,\n",
       "        -0.1195543 , -0.84759206,  0.25256675,  0.14532281, -0.01060636,\n",
       "        -0.03578382, -0.07117725,  0.10634375, -0.11669173,  0.17765476,\n",
       "        -0.03559739], dtype=float32),\n",
       " 'Doctorate': array([ 0.00375404, -0.02784416, -0.28326795,  0.22763273,  0.03977633,\n",
       "         0.2893272 ,  0.25680798,  0.36434892, -0.65951985, -0.23679003,\n",
       "        -0.11408209, -0.23283346, -0.27024168,  0.0655888 , -0.28381783,\n",
       "        -0.01525949], dtype=float32),\n",
       " '5th-6th': array([ 0.00683184,  0.23564084, -0.132059  , -0.3406017 , -0.06710123,\n",
       "        -0.09649926,  0.50411046, -0.12363172, -0.0353502 , -0.53238744,\n",
       "        -0.05181202, -0.05146485, -0.23931046, -0.26453286,  0.08420272,\n",
       "         0.0235041 ], dtype=float32),\n",
       " 'Assoc-voc': array([ 0.01930698, -0.2455314 ,  0.2246628 ,  0.16216752, -0.4528598 ,\n",
       "        -0.6121017 ,  0.15893641,  0.01993939, -0.3148845 ,  0.03837916,\n",
       "         0.0767131 , -0.36453167,  0.19929656,  0.28016493,  0.29385152,\n",
       "        -0.47822088], dtype=float32),\n",
       " '9th': array([-0.03110321,  0.69687057, -0.33127317,  0.06741869,  0.08373164,\n",
       "         0.25090563,  0.07099659,  0.21758935, -0.07414749, -0.19316533,\n",
       "         0.21613942,  0.28149685, -0.41364396, -0.0439614 , -0.02726781,\n",
       "        -0.04664526], dtype=float32),\n",
       " '12th': array([ 0.46782094,  0.1987633 ,  0.11554655, -0.23237073, -0.35828865,\n",
       "        -0.08366812,  0.0086338 ,  0.46672872, -0.24939838,  0.22630745,\n",
       "        -0.16754937, -0.4713689 , -0.08152255,  0.02004629,  0.1118032 ,\n",
       "         0.20979449], dtype=float32),\n",
       " '1st-4th': array([-0.16926417,  0.11347993,  0.02692448, -0.10284851,  0.25171363,\n",
       "        -0.04539176, -0.24491136,  0.3281045 , -0.08861455,  0.18578447,\n",
       "         0.23892452, -0.00729677,  0.16713212,  0.2949316 , -0.00725389,\n",
       "        -0.20607162], dtype=float32),\n",
       " 'Preschool': array([-0.30532706,  0.25465214, -0.5603218 , -0.16249408, -0.32321507,\n",
       "         0.11698078,  0.01557691, -0.3124683 , -0.25044286,  0.08334377,\n",
       "         0.2094927 ,  0.03301949, -0.01236501, -0.24443303, -0.03395106,\n",
       "        -0.01797807], dtype=float32),\n",
       " 'unseen': array([-0.17771505,  0.3246768 , -0.29062387,  0.12164559,  0.34164497,\n",
       "        -0.5451506 ,  0.22189835,  0.21224639,  0.4933099 , -0.03533744,\n",
       "        -0.12335563,  0.12472781,  0.1412489 ,  0.17336178,  0.4160364 ,\n",
       "        -0.32417113], dtype=float32)}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_embeddings(col_name='education', cat_encoding_dict=preprocess_deep.label_encoder.encoding_dict)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
